import subprocess
import mlperf_loadgen as lg
import argparse
import os
import sys

import sys
from backend import get_SUT

# Put the current working directory first on the module search path.
# NOTE(review): this runs after `from backend import get_SUT` above, so it
# does not affect how that import statement is resolved — confirm whether the
# insert was meant to come before it.
sys.path.insert(0, os.getcwd())


def get_args(argv=None):
    """Parse the command-line options for the benchmark run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so a
            plain ``get_args()`` call behaves exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``scenario``,
        ``dataset_path``, ``accuracy``, ``mlperf_conf``, ``user_conf``,
        ``max_examples``, ``make_vocab_size_divisible_by``,
        ``tensor_model_parallel_size`` and ``tokenizer_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scenario",
        choices=["SingleStream", "Offline", "Server", "MultiStream"],
        default="Offline",
        help="Scenario",
    )
    parser.add_argument(
        "--dataset-path",
        default="./data/cnn_eval.json",
        help="Path to the dataset file",
    )
    parser.add_argument("--accuracy", action="store_true", help="enable accuracy pass")
    parser.add_argument(
        "--mlperf_conf", default="mlperf.conf", help="mlperf rules config"
    )
    parser.add_argument(
        "--user_conf",
        default="user.conf",
        help="user config for user LoadGen settings such as target QPS",
    )
    # NOTE(review): the help text says "not limited by default" although the
    # default is 13368 — presumably that is the full dataset size; confirm.
    parser.add_argument(
        "--max_examples",
        type=int,
        default=13368,
        help="Maximum number of examples to consider (not limited by default)",
    )
    parser.add_argument(
        "--make-vocab-size-divisible-by",
        type=int,
        default=128,
        help="Make the vocab size divisible by",
    )
    parser.add_argument(
        "--tensor-model-parallel-size",
        type=int,
        default=8,
        help="Degree of tensor model parallelism.",
    )
    parser.add_argument(
        "--tokenizer-model",
        default="./data/c4_en_301_5Mexp2_spm.model",
        help="Path to tokenizer model",
    )
    return parser.parse_args(argv)


# Lookup table from the --scenario command-line choice to the LoadGen
# TestScenario member that carries the same name.
scenario_map = {
    name: getattr(lg.TestScenario, name)
    for name in ("SingleStream", "Offline", "Server", "MultiStream")
}


def _build_test_settings(args):
    """Create the LoadGen TestSettings for the parsed command-line options."""
    settings = lg.TestSettings()
    settings.scenario = scenario_map[args.scenario]
    # Need to update the conf
    # The rules config is loaded first and the user config second, both with
    # the "gpt3" model key and the scenario name.
    settings.FromConfig(args.mlperf_conf, "gpt3", args.scenario)
    settings.FromConfig(args.user_conf, "gpt3", args.scenario)
    # The mode is assigned after the config files are read, as before.
    if args.accuracy:
        settings.mode = lg.TestMode.AccuracyOnly
    else:
        settings.mode = lg.TestMode.PerformanceOnly
    return settings


def _build_log_settings():
    """Create the LoadGen LogSettings and make sure the log directory exists."""
    # An unset or empty LOG_PATH falls back to the default directory.
    log_path = os.environ.get("LOG_PATH") or "build/logs"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(log_path, exist_ok=True)

    log_output_settings = lg.LogOutputSettings()
    log_output_settings.outdir = log_path
    log_output_settings.copy_summary_to_stdout = True

    log_settings = lg.LogSettings()
    log_settings.log_output = log_output_settings
    log_settings.enable_trace = True
    return log_settings


def main():
    """Run one LoadGen test against the SUT described by the CLI options.

    Parses the command line, sets the RANK / MASTER_ADDR / MASTER_PORT
    environment variables, builds the SUT through ``get_SUT``, runs the test
    with logs written under ``$LOG_PATH`` (default ``build/logs``), and then
    destroys the SUT and QSL handles. The handles are destroyed even when
    configuration or the test run raises.
    """
    args = get_args()
    # NOTE(review): presumably consumed by a distributed backend set up inside
    # get_SUT — confirm. Existing values in the environment are overwritten.
    os.environ["RANK"] = "0"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    sut = get_SUT(
        scenario=args.scenario,
        dataset_path=args.dataset_path,
        max_examples=args.max_examples,
        args=args,
    )

    try:
        settings = _build_test_settings(args)
        log_settings = _build_log_settings()
        lg.StartTestWithLogSettings(sut.sut, sut.qsl, settings, log_settings)
        print("Test Done!")
    finally:
        # Release the LoadGen handles on both the success and failure paths.
        print("Destroying SUT...")
        lg.DestroySUT(sut.sut)

        print("Destroying QSL...")
        lg.DestroyQSL(sut.qsl)


if __name__ == "__main__":
    main()
